{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model.blip2_llama_inference import Blip2Llama\n",
    "from model.unimol import SimpleUniMolModel\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import AllChem\n",
    "from unicore.data import Dictionary\n",
    "import numpy as np\n",
    "import torch\n",
    "from scipy.spatial import distance_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "smiles = 'CC(=O)OC(CC(=O)[O-])C[N+](C)(C)C'\n",
    "# without sdf file, 3D coordinates are generated using RDKit\n",
    "mol_file = 'Conformer3D_COMPOUND_CID_1.sdf'\n",
    "# mol_file = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor_type = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Namespace(bert_name='all_checkpoints/scibert_scivocab_uncased', cross_attention_freq=2, enable_flash=False, llm_model='all_checkpoints/llama-2-7b-hf', lora_alpha=32, lora_dropout=0.1, lora_path='all_checkpoints/generalist/generalist.ckpt', lora_r=8, num_query_token=8, unimol_activation_dropout=0.0, unimol_activation_fn='gelu', unimol_attention_dropout=0.1, unimol_delta_pair_repr_norm_loss=-1.0, unimol_dropout=0.1, unimol_emb_dropout=0.1, unimol_encoder_attention_heads=64, unimol_encoder_embed_dim=512, unimol_encoder_ffn_embed_dim=2048, unimol_encoder_layers=15, unimol_max_atoms=256, unimol_max_seq_len=512)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import argparse\n",
    "def get_args():\n",
    "    parser = argparse.ArgumentParser()\n",
    "    ### models\n",
    "    parser.add_argument('--bert_name', type=str, default='all_checkpoints/scibert_scivocab_uncased')\n",
    "    parser.add_argument('--llm_model', type=str, default='all_checkpoints/llama-2-7b-hf')\n",
    "    \n",
    "    ### flash attention\n",
    "    parser.add_argument('--enable_flash', action='store_false', default=False)\n",
    "\n",
    "    ### lora settings\n",
    "    parser.add_argument('--lora_r', type=int, default=8)\n",
    "    parser.add_argument('--lora_alpha', type=int, default=32)\n",
    "    parser.add_argument('--lora_dropout', type=int, default=0.1)\n",
    "    parser.add_argument('--lora_path', type=str, default='all_checkpoints/generalist/generalist.ckpt')\n",
    "\n",
    "    ### q-former settings\n",
    "    parser.add_argument('--cross_attention_freq', type=int, default=2)\n",
    "    parser.add_argument('--num_query_token', type=int, default=8)\n",
    "\n",
    "    parser = SimpleUniMolModel.add_args(parser)\n",
    "\n",
    "    args, unknown = parser.parse_known_args()\n",
    "    return args\n",
    "args = get_args()\n",
    "args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.90s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loaded model from all_checkpoints/generalist/generalist.ckpt\n"
     ]
    }
   ],
   "source": [
    "model = Blip2Llama(args).to(tensor_type)\n",
    "device = torch.device(\"cpu\")\n",
    "model.to(device)\n",
    "tokenizer = model.llm_tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def smiles2graph(smiles):\n",
    "    mol = Chem.MolFromSmiles(smiles)\n",
    "    mol = AllChem.AddHs(mol)\n",
    "    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
    "    if (np.asarray(atoms) == 'H').all():\n",
    "        return None\n",
    "    coordinate_list = []\n",
    "    res = AllChem.EmbedMolecule(mol)\n",
    "    if res == 0:\n",
    "        try:\n",
    "            AllChem.MMFFOptimizeMolecule(mol)\n",
    "        except:\n",
    "            pass\n",
    "        coordinates = mol.GetConformer().GetPositions()\n",
    "    elif res == -1:\n",
    "        mol_tmp = Chem.MolFromSmiles(smiles)\n",
    "        AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000)\n",
    "        mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)\n",
    "        try:\n",
    "            AllChem.MMFFOptimizeMolecule(mol_tmp)\n",
    "        except:\n",
    "            pass\n",
    "        coordinates = mol_tmp.GetConformer().GetPositions()\n",
    "    coordinates = coordinates.astype(np.float32)\n",
    "    assert len(atoms) == len(coordinates), \"coordinates shape is not align with {}\".format(smiles)\n",
    "    assert coordinates.shape[1] == 3\n",
    "    \n",
    "    atoms = np.asarray(atoms)\n",
    "    ## remove the hydrogen\n",
    "    mask_hydrogen = atoms != \"H\"\n",
    "    if sum(mask_hydrogen) > 0:\n",
    "        atoms = atoms[mask_hydrogen]\n",
    "        coordinates = coordinates[mask_hydrogen]\n",
    "\n",
    "    ## atom vectors\n",
    "    dictionary = Dictionary.load('data_provider/unimol_dict.txt')\n",
    "    dictionary.add_symbol(\"[MASK]\", is_special=True)\n",
    "    atom_vec = torch.from_numpy(dictionary.vec_index(atoms)).long()\n",
    "\n",
    "    ## normalize coordinates:\n",
    "    coordinates = coordinates - coordinates.mean(axis=0)\n",
    "\n",
    "    ## add_special_token:\n",
    "    atom_vec = torch.cat([torch.LongTensor([dictionary.bos()]), atom_vec, torch.LongTensor([dictionary.eos()])])\n",
    "    coordinates = np.concatenate([np.zeros((1, 3)), coordinates, np.zeros((1, 3))], axis=0)\n",
    "    \n",
    "    ## obtain edge types; which is defined as the combination of two atom types\n",
    "    edge_type = atom_vec.view(-1, 1) * len(dictionary) + atom_vec.view(1, -1)\n",
    "    dist = distance_matrix(coordinates, coordinates).astype(np.float32)\n",
    "    coordinates, dist = torch.from_numpy(coordinates), torch.from_numpy(dist)\n",
    "\n",
    "    return atom_vec, dist, edge_type\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sdf2graph(sdf_file):\n",
    "    molecules = Chem.SDMolSupplier(sdf_file)\n",
    "    for molecule in molecules:\n",
    "\n",
    "        # Get 3D Conformer if available\n",
    "        conformers = []\n",
    "        num_conformers = molecule.GetNumConformers()\n",
    "        for i in range(num_conformers):\n",
    "            conformer = molecule.GetConformer(i)\n",
    "            conformers.append(conformer.GetPositions())\n",
    "\n",
    "        # Get atoms and coordinates\n",
    "        atoms = [atom.GetSymbol() for atom in molecule.GetAtoms()]\n",
    "        coordinates = conformers[0].astype(np.float32)\n",
    "\n",
    "        assert len(atoms) == len(coordinates), \"coordinates shape is not align with {}\".format(smiles)\n",
    "        assert coordinates.shape[1] == 3\n",
    "\n",
    "        atoms = np.asarray(atoms)\n",
    "        ## remove the hydrogen\n",
    "        mask_hydrogen = atoms != \"H\"\n",
    "        if sum(mask_hydrogen) > 0:\n",
    "            atoms = atoms[mask_hydrogen]\n",
    "            coordinates = coordinates[mask_hydrogen]\n",
    "\n",
    "        ## atom vectors\n",
    "        dictionary = Dictionary.load('data_provider/unimol_dict.txt')\n",
    "        dictionary.add_symbol(\"[MASK]\", is_special=True)\n",
    "        atom_vec = torch.from_numpy(dictionary.vec_index(atoms)).long()\n",
    "\n",
    "        ## normalize coordinates:\n",
    "        coordinates = coordinates - coordinates.mean(axis=0)\n",
    "\n",
    "        ## add_special_token:\n",
    "        atom_vec = torch.cat([torch.LongTensor([dictionary.bos()]), atom_vec, torch.LongTensor([dictionary.eos()])])\n",
    "        coordinates = np.concatenate([np.zeros((1, 3)), coordinates, np.zeros((1, 3))], axis=0)\n",
    "        \n",
    "        ## obtain edge types; which is defined as the combination of two atom types\n",
    "        edge_type = atom_vec.view(-1, 1) * len(dictionary) + atom_vec.view(1, -1)\n",
    "        dist = distance_matrix(coordinates, coordinates).astype(np.float32)\n",
    "        coordinates, dist = torch.from_numpy(coordinates), torch.from_numpy(dist)\n",
    "\n",
    "        return atom_vec, dist, edge_type\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_3d_graph(smiles=None, sdf_file=None):\n",
    "    if sdf_file is not None:\n",
    "        d3_graph = sdf2graph(sdf_file)\n",
    "    elif smiles is not None:\n",
    "        d3_graph = smiles2graph(smiles)\n",
    "    else:\n",
    "        raise ValueError('smiles must be provided')\n",
    "    return d3_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(tokenizer, text):\n",
    "    text_tokens = tokenizer(text,\n",
    "                            add_special_tokens=True,\n",
    "                            return_tensors='pt',\n",
    "                            return_attention_mask=True,\n",
    "                            return_token_type_ids=True)\n",
    "    is_mol_token = text_tokens.input_ids == tokenizer.mol_token_id\n",
    "    text_tokens['is_mol_token'] = is_mol_token\n",
    "    assert torch.sum(is_mol_token).item() == 8\n",
    "\n",
    "    return text_tokens\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[    2, 13866,   338,   385, 15278,   393, 16612,   263,  3414, 29892,\n",
       "          3300,  2859,   411,   385,  1881, 13206, 29883,  1297, 29889, 14350,\n",
       "           263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009, 29889,\n",
       "            13,  3379,  4080, 29901,   306,   817,   304,  1073,   278,  4522,\n",
       "         29925,   310,   445, 13206, 29883,  1297, 29892,  1033,   366,  3113,\n",
       "          3867,   372, 29973,   960, 17999, 29892,  3867,   385, 12678, 29889,\n",
       "          2538,  2818,   411,   278, 16259,   995,   871, 29889,    13,  4290,\n",
       "         13206, 29883,  1297, 29901, 19178, 29898, 29922, 29949, 29897, 20166,\n",
       "         29898,  4174, 29898, 29922, 29949,  9601, 29949, 29899,  2314, 29907,\n",
       "         29961, 29940, 29974,   850, 29907,  5033, 29907, 29897, 29907, 29871,\n",
       "         32001, 32001, 32001, 32001, 32001, 32001, 32001, 32001,   869,    13,\n",
       "          5103, 29901, 29871]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'is_mol_token': tensor([[False, False, False, False, False, False, False, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "         False, False, False, False, False, False, False, False, False, False,\n",
       "          True,  True,  True,  True,  True,  True,  True,  True, False, False,\n",
       "         False, False, False]])}"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "atom_vec, dist, edge_type = get_3d_graph(smiles=smiles, sdf_file=mol_file)\n",
    "atom_vec, dist, edge_type = atom_vec.unsqueeze(0), dist.unsqueeze(0).to(tensor_type), edge_type.unsqueeze(0)\n",
    "atom_vec, dist, edge_type = atom_vec.to(device), dist.to(device), edge_type.to(device)\n",
    "graph = (atom_vec, dist, edge_type)\n",
    "prompt = \"Below is an instruction that describes a task, paired with an input molecule. Write a response that appropriately completes the request.\\n\" \\\n",
    "         \"Instruction: {}\\n\" \\\n",
    "         \"Input molecule: {} <mol><mol><mol><mol><mol><mol><mol><mol>.\\n\" \\\n",
    "         \"Response: \"\n",
    "instruction = \"I need to know the LogP of this molecule, could you please provide it? If uncertain, provide an estimate. Respond with the numerical value only.\"\n",
    "input = prompt.format(instruction, smiles)\n",
    "input_tokens = tokenize(tokenizer, input)\n",
    "input_tokens.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['The LogP for the input molecule is 0.80.\\nThe LogP for O-acetylcarnitine is 0.80. O-acetylcarnitine is a metabolite found in or produced by Saccharomyces cerevisiae.\\nThe LogP for O-acetylcarnitine is 0.80. O-acetylcarnitine is a metabolite found in or produced by Escherichia coli (strain K12, MG1655).\\nThe']\n"
     ]
    }
   ],
   "source": [
    "output = model.generate(graph, input_tokens)\n",
    "print(output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "3D-MoLM",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
